Skip to main content

About PyTorch Lightning

  • It is said to be the most actively developed PyTorch wrapper right now, so I gave it a try

What is PyTorch Lightning?

PyTorch Lightning is a Keras-like ML library for PyTorch that lets engineers write only the core training and validation logic while it automates the rest. (It strikes a nice balance between usability and flexibility because, unlike Keras, it does not oversimplify things.)

fast.ai was created for educating people interested in deep learning, while Lightning, Ignite, and Catalyst were built for researchers who use ML. Lightning in particular was designed to improve research reproducibility and let researchers focus more on their research.

  • Since the Trainer and LightningModule are separated, you can focus on the LightningModule part.
pytorch-lightning

Lightning vs Ignite 💥

  • Lightning makes it clearer where things happen compared to Ignite.
  • With Lightning, you only need to focus on 9 functions

Reference code:

PyTorch-Lightning
class CoolModel(ptl.LightningModule):
    """Minimal MNIST example showing the hooks a LightningModule provides.

    The network is a single linear layer applied to the flattened input,
    followed by ReLU. The remaining methods supply the training step, the
    validation step and its aggregation, the optimizer, and the dataloaders.
    """

    def __init__(self):
        super(CoolModel, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        """Flatten each sample in ``x`` and return ``relu(l1(x))``."""
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        """Return ``{'loss': ...}`` with the cross-entropy loss for one batch."""
        x, y = batch
        y_hat = self.forward(x)
        # F.cross_entropy returns the loss tensor directly; it must not be
        # called a second time (a tensor is not callable).
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        """Return ``{'val_loss': ...}`` with the cross-entropy loss for one batch."""
        x, y = batch
        y_hat = self.forward(x)
        # Same single call as in training_step.
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        """Average the ``'val_loss'`` entries of ``outputs``.

        ``outputs`` is iterated as a sequence of dicts, each holding a
        ``'val_loss'`` tensor; the mean is returned under ``'avg_val_loss'``.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'avg_val_loss': avg_loss}

    def configure_optimizers(self):
        """Return a one-element list holding an Adam optimizer (lr=0.02)."""
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    # NOTE(review): all three loaders below request the ``train=True`` split,
    # so validation and test run on training data. This looks like a
    # simplification for the demo -- confirm before reusing.

    @ptl.data_loader
    def tng_dataloader(self):
        """Return the training DataLoader (MNIST, batch size 32)."""
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    @ptl.data_loader
    def val_dataloader(self):
        """Return the validation DataLoader (MNIST, batch size 32)."""
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    @ptl.data_loader
    def test_dataloader(self):
        """Return the test DataLoader (MNIST, batch size 32)."""
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

Ignite
class Net(nn.Module):
    """Small convolutional classifier used by the Ignite example.

    Two convolution layers (the second followed by 2-D dropout) feed two
    linear layers; ``forward`` returns ``log_softmax`` over the last dimension
    of the 10-unit output.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order: conv1, conv2, conv2_drop, fc1, fc2.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Run the network on ``x`` and return log-probabilities.

        Assumes ``x`` is shaped so that the features flatten to 320 per
        sample (e.g. 1x28x28 images) -- TODO confirm against the caller.
        """
        # First conv stage: convolve, pool by 2, then ReLU.
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))

        # Second conv stage: convolve, channel dropout, pool by 2, then ReLU.
        conv2_out = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(conv2_out, 2))

        # Flatten to 320 features per row for the linear layers.
        flat = stage2.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        # Functional dropout is active only while the module is in training mode.
        hidden = F.dropout(hidden, training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=-1)


def get_data_loaders(train_batch_size, val_batch_size):
    """Create the MNIST training and validation loaders.

    Both datasets live under ``"."`` and share one transform
    (``ToTensor`` followed by ``Normalize((0.1307,), (0.3081,))``).

    Args:
        train_batch_size: batch size for the (shuffled) training loader.
        val_batch_size: batch size for the (unshuffled) validation loader.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

    # The training split is requested first with download=True; the
    # validation split is then requested with download=False.
    train_set = MNIST(download=True, root=".", transform=data_transform, train=True)
    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True)

    val_set = MNIST(download=False, root=".", transform=data_transform, train=False)
    val_loader = DataLoader(val_set, batch_size=val_batch_size, shuffle=False)

    return train_loader, val_loader


def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    """Train ``Net`` with Ignite and report metrics after every epoch.

    Builds the loaders via ``get_data_loaders``, creates a supervised trainer
    and evaluator using ``F.nll_loss``, registers progress/metric handlers,
    and runs training for ``epochs`` epochs.

    Args:
        train_batch_size: batch size passed to ``get_data_loaders`` for training.
        val_batch_size: batch size passed to ``get_data_loaders`` for validation.
        epochs: value passed as ``max_epochs`` to ``trainer.run``.
        lr: learning rate for the SGD optimizer.
        momentum: momentum for the SGD optimizer.
        log_interval: the progress bar is updated every ``log_interval``
            iterations within an epoch.
    """
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
    evaluator = create_supervised_evaluator(
        model,
        metrics={'accuracy': Accuracy(), 'nll': Loss(F.nll_loss)},
        device=device,
    )

    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(
        initial=0, leave=False, total=len(train_loader),
        desc=desc.format(0)
    )

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # 1-based position of this iteration within the current epoch.
        # (Named so it does not shadow the ``iter`` builtin.)
        iteration_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1

        if iteration_in_epoch % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll)
        )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

        # Reset the bar's counters so the next epoch starts from zero.
        pbar.n = pbar.last_print_n = 0

    # Close the progress bar even if training raises, so it is never left open.
    try:
        trainer.run(train_loader, max_epochs=epochs)
    finally:
        pbar.close()

Using GPUs is easier than with Ignite 🎊

PyTorch-Lightning_Demo
# Create a Trainer with gpus=[0,1,2,3] (presumably the GPU device indices to
# train on -- confirm against the Lightning Trainer documentation).
trainer = Trainer(gpus=[0,1,2,3])
# Hand the model to the trainer's fit() to start training.
trainer.fit(model)
Ignite_Demo
# inside the run function

# Choose the device string; the model is moved to args.gpu only when CUDA
# is available.
cuda_available = torch.cuda.is_available()
device = "cuda" if cuda_available else "cpu"
if cuda_available:
    model.cuda(args.gpu)

# Wrap the model for distributed training when requested.
if args.distributed:
    model = DistributedDataParallel(model, [args.gpu])

optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
evaluator = create_supervised_evaluator(
    model,
    metrics={"accuracy": Accuracy(), "nll": Loss(F.nll_loss)},
    device=device,
)
  • It would be great if paper implementations were written in PyTorch Lightning, since it makes clear what is being done and where...

Rich Logger Support

Install the logger you want with `pip install [logger_you_want]`, then import it with `from pytorch_lightning.loggers import ...`.

Comparison of Activity Levels with Other Wrappers (as of 2020/7/10) ✔️

  • Issues and PRs are combined totals of open and closed
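| | Star | Commits | Issues | PR | Notes |
|---|---|---|---|---|---|
| Lightning | 6.6k | 2671 | 1331 | 1241 | Currently the most active? |
| fast ai | 18.3k | 5444 | 1056 | 1552 | No commits since May (mature) |
| Ignite | 2.9k | 641 | 517 | 674 | Developed by the PyTorch team |
| catalyst | 2k | 1260 | 226 | 654 | Steady development? |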